{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Symbolic Computation Speedups"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook is a simple example of how symbolic computation can greatly reduce the number of operations needed to evaluate an expression."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setup\n",
    "import symforce\n",
    "\n",
    "symforce.set_symbolic_api(\"sympy\")\n",
    "\n",
    "import symforce.symbolic as sf\n",
    "from symforce.notebook_util import display"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Autodiff vs. Symbolic Differentiation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Multiplying three matrices (for example to chain rule Jacobians)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dense case: build symbolic A (5x5), B (5x3) and C (3x3), so the product A * B * C is 5x3\n",
    "A, B, C = (\n",
    "    sf.Matrix.zeros(rows, cols).symbolic(prefix)\n",
    "    for prefix, rows, cols in ((\"A\", 5, 5), (\"B\", 5, 3), (\"C\", 3, 3))\n",
    ")\n",
    "display(A, B, C)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the chained product of the dense operands\n",
    "dense_product = A * B * C\n",
    "display(dense_product)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Operation count of the normal matrix multiplication (this form can happen with SIMD)\n",
    "dense_product = A * B * C\n",
    "sf.count_ops(dense_product)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### The same with typical sparsity patterns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A: diagonal built from the symbols \"A:5\", plus one extra entry A5 in the top-right corner\n",
    "A = sf.Matrix.diag(sf.symbols(\"A:5\"))\n",
    "A[0, 4] = sf.Symbol(\"A5\")\n",
    "\n",
    "# B: 5x3 with a diagonal top 3x3 block, row 3 filled with the same symbol B3, and B4 at [4, 1]\n",
    "B = sf.Matrix.zeros(5, 3)\n",
    "B[:3, :3] = sf.Matrix.diag(sf.symbols(\"B:3\"))\n",
    "for col in range(3):\n",
    "    B[3, col] = sf.Symbol(\"B3\")\n",
    "B[4, 1] = sf.Symbol(\"B4\")\n",
    "\n",
    "# C: the matrix that Rot3.hat builds from the three symbols \"C:3\"\n",
    "C = sf.M(sf.Rot3.hat(sf.symbols(\"C:3\")))\n",
    "display(A, B, C)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the same chained product, now with the sparse operands\n",
    "sparse_product = A * B * C\n",
    "display(sparse_product)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Operation count when the sparse result is computed directly\n",
    "sparse_product = A * B * C\n",
    "sf.count_ops(sparse_product)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the sparse product into shared sub-expressions plus the remaining output terms\n",
    "intermediates, output = sf.cse(A * B * C)\n",
    "\n",
    "# Show every extracted sub-expression as an equation\n",
    "for lhs, rhs in intermediates:\n",
    "    display(sf.sympy.Eq(lhs, rhs))\n",
    "\n",
    "# Reassemble the output terms into a 5x3 matrix\n",
    "output_mat = sf.Matrix53(output)\n",
    "display(output_mat)\n",
    "\n",
    "# Total cost: operations in all intermediates plus those left in the output matrix\n",
    "num_operations = sum(sf.count_ops(rhs) for _, rhs in intermediates) + sf.count_ops(output_mat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total operation count with common sub-expression elimination applied\n",
    "display(num_operations)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Code generation: one C-style temporary per shared sub-expression, a blank line, then each output entry\n",
    "c_lines = [f\"float {lhs} = {rhs};\" for lhs, rhs in intermediates]\n",
    "c_lines.append(\"\")\n",
    "c_lines.extend(f\"out[{i}] = {out};\" for i, out in enumerate(output))\n",
    "for line in c_lines:\n",
    "    print(line)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Further points:\n",
    "\n",
    "* This is a simple example: every matrix entry is a leaf symbol\n",
    "* The chain is only three matrices deep\n",
    "* The same idea applies to other domains, not just derivatives and optimization\n",
    "* The comparison does not consider dynamic allocation (matrices on the heap) or pointer chasing\n",
    "* It also does not consider that you had to write, test, and debug the Jacobians in the first place"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
